# Preprocessing pipeline: load reviews, pretrain word2vec, train the classifier

import pandas as pd
import numpy as np
import jieba
import re
from sklearn.model_selection import train_test_split
from gensim.models.word2vec import Word2Vec
from nlu_model.util.pkl_impl import save_pkl, load_pkl
from nlu_model.cls.model_pytorch.model_train import TrainModelPipeline


# data preprocess
def loadfile(test_size=0.2, random_state=None):
    """Load positive/negative shopping reviews, clean and tokenize them,
    and return a train/test split.

    Args:
        test_size: fraction of samples held out for testing (default 0.2,
            same as the original behavior).
        random_state: optional seed forwarded to train_test_split for a
            reproducible split (default None keeps the original behavior).

    Returns:
        x_train, x_test: arrays of token lists (jieba word segments).
        y_train, y_test: label arrays (1.0 = positive, 0.0 = negative).
    """
    # The Excel files have no header row; column 0 holds the raw review text.
    # BUG FIX: `index=None` is not a valid read_excel keyword (modern pandas
    # raises TypeError on it) — the intended argument is `index_col=None`.
    neg = pd.read_excel('./data/cls/shopping_reviews/neg.xls', header=None, index_col=None)
    pos = pd.read_excel('./data/cls/shopping_reviews/pos.xls', header=None, index_col=None)

    # Strip ASCII + full-width Chinese punctuation, then segment with jieba.
    # Compiled once here instead of on every row.
    punctuation = re.compile(r"[\s+\.\!\/_,$%^*(+\"\']+|[+——！，。？、~@#￥%……&*（）：]+")

    def cw(x):
        # str(x) guards against non-string cells (e.g. numeric reviews).
        return list(jieba.cut(punctuation.sub("", str(x))))

    pos['words'] = pos[0].apply(cw)
    neg['words'] = neg[0].apply(cw)

    # Labels: 1 for positive reviews, 0 for negative ones.
    y = np.concatenate((np.ones(len(pos)), np.zeros(len(neg))))

    x_train, x_test, y_train, y_test = train_test_split(
        np.concatenate((pos['words'], neg['words'])), y,
        test_size=test_size, random_state=random_state)

    return x_train, x_test, y_train, y_test

# pretrain_w2v
def pretrain_w2v(x_train):
    """Train a Word2Vec model on tokenized sentences and build the
    embedding matrix plus the word->index dictionary for the classifiers.

    Row layout of the embedding matrix: row 0 = <PAD> (all zeros),
    rows 1..V = vocabulary words, row V+1 = <UNK> (mean of all word vectors).

    Args:
        x_train: iterable of token lists (one list per sentence).

    Returns:
        embedding_weights: list of N_DIM-sized vectors, row i = vector of index i.
        imdb_w2v: the trained gensim Word2Vec model.
        word2idx_dic: dict mapping word -> row index (incl. <PAD>/<UNK>).
    """
    N_DIM = 300                         # word2vec vector dimensionality
    MIN_COUNT = 5                       # min frequency for a word to enter the vocab
    w2v_EPOCH = 15                      # word2vec training epochs
    MAXLEN = 50                         # max sentence length (saved for downstream config)

    # Initialize model and build vocab.
    # NOTE(review): gensim<4 API (`size`, `wv.vocab`); gensim>=4 renamed these
    # to `vector_size` / `wv.key_to_index` — confirm the pinned gensim version.
    imdb_w2v = Word2Vec(size=N_DIM, min_count=MIN_COUNT)
    imdb_w2v.build_vocab(x_train)

    # Train the model over train_reviews (this may take several minutes).
    imdb_w2v.train(x_train, total_examples=len(x_train), epochs=w2v_EPOCH)
    print("model train done")

    # +2 rows: one for <PAD> (index 0) and one for <UNK> (last index).
    n_symbols = len(imdb_w2v.wv.vocab.keys()) + 2
    embedding_weights = [[0 for _ in range(N_DIM)] for _ in range(n_symbols)]
    word2idx_dic = {}
    idx = 1  # row 0 stays all-zero for <PAD>
    for w in imdb_w2v.wv.vocab.keys():
        # Vector lookup goes through .wv; plain model indexing is deprecated.
        embedding_weights[idx] = imdb_w2v.wv[w]
        word2idx_dic[w] = idx
        idx = idx + 1

    # <UNK> slot (out-of-vocabulary words): mean of all learned word vectors.
    embedding_weights[idx] = list(np.mean(embedding_weights[1:idx], axis=0))
    word2idx_dic["<UNK>"] = idx

    # <PAD> slot.
    word2idx_dic["<PAD>"] = 0

    # Persist the word->index dictionary (pickle + human-readable text).
    save_pkl('./data/ptm/shopping_reviews/w2v_word2idx2020100601.pkl', word2idx_dic)
    with open('./data/ptm/shopping_reviews/w2v_word2idx2020100601.txt', 'w') as f:
        for word, word_idx in word2idx_dic.items():
            f.write("%s\t%s\n" % (word, word_idx))

    # Persist the embedding matrix.
    # BUG FIX: the newline was written as the literal "/n" instead of "\n",
    # producing a single-line text dump.
    save_pkl("./data/ptm/shopping_reviews/w2v_model_metric_2020100601.pkl", embedding_weights)
    with open("./data/ptm/shopping_reviews/w2v_model_metric_2020100601.txt", "w") as f:
        for line in embedding_weights:
            f.write("%s\n" % ",".join(str(i) for i in line))

    # Persist the hyper-parameters used to build the embeddings.
    save_pkl("./data/ptm/shopping_reviews/w2v_model_conf_2020100601.pkl",
             [N_DIM, MIN_COUNT, w2v_EPOCH, MAXLEN])

    return embedding_weights, imdb_w2v, word2idx_dic



def word2idx(source_data, word2idx_dic, padding=-1):
    """Convert tokenized sentences into index sequences.

    Unknown tokens map to the "<UNK>" entry of word2idx_dic; if the dict
    has no "<UNK>" key, the old fallback (last index, len(dic)-1) is kept
    for backward compatibility. When padding >= 0, every sentence is
    truncated / right-padded with 0 (<PAD>) to exactly `padding` positions.

    Args:
        source_data: iterable of token lists.
        word2idx_dic: word -> index mapping (index 0 reserved for <PAD>).
        padding: fixed sequence length, or -1 to keep original lengths.

    Returns:
        list of np.ndarray index sequences.
    """
    # BUG FIX: the old code hard-coded len(word2idx_dic)-1 for OOV tokens,
    # which only coincidentally equals the <UNK> row; resolve it explicitly.
    unk_idx = word2idx_dic.get("<UNK>", len(word2idx_dic) - 1)
    result_data = []
    for tokens in source_data:
        sentence = [word2idx_dic.get(tok, unk_idx) for tok in tokens]
        if padding >= 0:
            # Truncate, then right-pad with 0 ([0] * negative is empty).
            sentence = sentence[:padding] + [0] * (padding - len(sentence))
        result_data.append(np.array(sentence))
    return result_data

# ---- script body: build datasets, embeddings, and index sequences ----
# Dump the split as "label<TAB>space-joined tokens" text files, pretrain
# word2vec, then turn every sentence into a fixed-length index sequence.
x_train, x_test, y_train, y_test = loadfile()

with open("./data/cls/shopping_reviews/train.txt", "w") as f:
    for label, tokens in zip(y_train, x_train):
        f.write("%s\t%s\n" % (label, " ".join(tokens)))

with open("./data/cls/shopping_reviews/test.txt", "w") as f:
    for label, tokens in zip(y_test, x_test):
        f.write("%s\t%s\n" % (label, " ".join(tokens)))

embedding_weights, imdb_w2v, word2idx_dic = pretrain_w2v(x_train)
x_train = word2idx(x_train, word2idx_dic, padding=20)
x_test = word2idx(x_test, word2idx_dic, padding=20)

# ---- model configurations ----
# Hyper-parameters shared by every classifier head; each model dict below
# starts from this base and adds/overrides its own keys.
_shared_cls_param = {
    "vocab_size": len(word2idx_dic),
    "embed_dim": 300,
    "class_num": 2,
    "kernel_num": 16,
    "kernel_size": [1, 2, 3, 4, 5],
    "dropout": 0.5,
    "pre_word_embeds": embedding_weights,
}

# auc: 0.94836296, acc: 0.882729211087
textCNN_param = dict(_shared_cls_param,
                     model_name="TEXTCNN",
                     learning_rate=0.001)

# auc: 0.94508814, acc: 0.8993129590144515
textRCNN_param = dict(_shared_cls_param,
                      model_name="TEXTRCNN",
                      learning_rate=0.001,
                      lstm_hidden=128,
                      lstm_num_layers=2,
                      pad_size=20)

# auc: 0.91433352, acc: 0.8756218905472637
textRNN_param = dict(_shared_cls_param,
                     model_name="TEXTRNN",
                     learning_rate=0.001,
                     lstm_hidden=128,
                     lstm_num_layers=2)

# 2 layers: auc 0.93021430, acc 0.868751480
# 3 layers: auc 0.9244669,  acc 0.865197820421701
transformer_cls_param = dict(_shared_cls_param,
                             model_name="TRANSFORMERCLS",
                             learning_rate=0.0001,
                             lstm_hidden=128,
                             lstm_num_layers=2,
                             pad_size=20)

# ---- training ----
train_config = {
    "MODEL_CONF": transformer_cls_param,
    "batch_size": 64,
    "epoch": 20,
}
train_model_pipeline = TrainModelPipeline(train_config)
train_model_pipeline.call_train(x_train, y_train)
train_model_pipeline.call_evaluate(x_train, y_train)  # train-set metrics
train_model_pipeline.call_evaluate(x_test, y_test)    # held-out metrics
